Add batch inferencing support for GPT2LMHeadModel #7552

patrickvonplaten merged 3 commits into huggingface:master
Conversation
This enables significantly faster generation.

```python
# following above code
data = sentences * 128  # total 256 sentences
model.cuda();
data = [' '.join([x] * 10) for x in data]  # make the prompt longer to be more realistic

from tqdm.auto import tqdm

def test(batchsize=1, max_gen_len=20):
    for i in tqdm(range(0, len(data), batchsize)):
        batch = data[i: i + batchsize]
        inputs = tokenizer(batch, return_tensors="pt", padding=True)
        output_sequences = model.generate(
            input_ids=inputs['input_ids'].to(model.device),
            attention_mask=inputs['attention_mask'].to(model.device),
            do_sample=False,  # disable sampling to test if batching affects output
            pad_token_id=tokenizer.eos_token_id,
            max_length=len(inputs['input_ids'][0]) + max_gen_len,  # let it generate longer
        )
        outputs = [tokenizer.decode(x) for x in output_sequences]

%time test(1, 20)
%time test(32, 20)
%time test(1, 100)
%time test(32, 100)
```
Hey @cccntu - this is a great addition! I very much like your approach here. With the current implementation, the user would not be able to define his own `position_ids`.

@LysandreJik - this feature was heavily requested by the community (linked a couple of issues below) and I think this is a great way to handle GPT2 batch generation. What do you think?
@cccntu - Great work on this PR! If this PR is merged and you want to help the community a tiny bit more, you could give a short description (similar to what you've done above) on how to do batch generation with GPT2 here: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517. Many people have been asking for this so they would be very glad to see a short forum post about it. Thanks a lot again!
```python
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
    # create position_ids on the fly for batch generation
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if past:
        position_ids = position_ids[:, -1].unsqueeze(-1)
else:
    position_ids = None
```
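To make the effect of this snippet concrete, here is a small self-contained sketch (the mask values are made up for illustration): for a left-padded batch, the cumulative sum of the attention mask gives each real token its correct absolute position starting at 0, while padded slots just get a dummy value of 1 (they are masked out anyway).

```python
import torch

# Hypothetical left-padded batch: 0 = padding, 1 = real token.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Same computation as in prepare_inputs_for_generation above.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```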
@patrickvonplaten
Now that you add `position_ids = kwargs.get("position_ids", None)`, I think we can get rid of the `else: position_ids = None` branch.
Also, inspired by the related PR #7355, I think we should move all the `if past` handling together, just above the `return`. Should I add another commit?
No strong opinions on this, will let @patrickvonplaten decide to merge with or without this
@cccntu - yeah, I thought about this as well. The problem with this, with PR #7355, and with passing position_ids in general is that we would have to incrementally append new positions to position_ids inside generate(), which would be pretty hacky since not all models support position_ids => so I'd rather not do this before a bigger refactor of generate(), see #6949 (will continue on the bigger refactor soon).
We can always change that later without breaking backwards compatibility.
LysandreJik left a comment
This is great, very simple implementation! Thanks a lot @cccntu.
Awesome, great work @cccntu! It would be amazing if you could write a little description of how your PR works on the forum: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517 - the community would be very thankful I think :-)
@patrickvonplaten Thanks for the suggestions! I just added some description to the forum post. 😃 Link to the post for future reference: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517/2
Can you please add batch inferencing for GPT2DoubleHeadsModel too?
I can see how batch generation is now available. I was wondering if there's already a way to do the same but with different arguments of `generate()` for each input.
Hi @spate141, did you mean passing a different set of arguments for each input? Actually, the main issue is that we need the right-most logits to not be padding, and without modifying generation_utils.py we need to use left-padding, and consequently we need this PR to make sure the positional embeddings are correct.
You can also check out the discussion in #3021, or the forum post: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517/3
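To illustrate the left-padding point, here is a minimal sketch (the prompts are just placeholders) showing why the default right padding breaks generation while left padding keeps the right-most token of every row a real token:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token

batch = ["Hello, my dog is a little", "Today, I"]

# Right padding: the shorter prompt ends in pad tokens, but generate()
# reads the next-token logits at the last position of each row.
tokenizer.padding_side = "right"
right = tokenizer(batch, return_tensors="pt", padding=True)
print(right["input_ids"][1, -1].item() == tokenizer.eos_token_id)  # True -> conditioning on padding

# Left padding: every row's right-most token is a real token, and this PR
# fixes the position ids so the absolute positions stay correct.
tokenizer.padding_side = "left"
left = tokenizer(batch, return_tensors="pt", padding=True)
print(left["input_ids"][1, -1].item() == tokenizer.eos_token_id)  # False
```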
I saw the code and I can see why it will fail. #3021 seems informative, I'll take a look. Meanwhile I found a few ways to get what I mentioned.

@cccntu In your 2nd comment to this pull request, you posted some impressive results on why doing batch generation is ideal, especially, let's say, when you have a GPU. I'm just trying to figure out if doing the same in my case is worth the latency when I have to do some post-processing. I'll post some latency results once I have this setup ready.
Update: @cccntu I went with my first approach, where I'm generating text for all texts in a single batch with global min/max values. In most cases the last text chunk in a batch is smaller, meaning its min/max values are smaller than those of the rest of the text chunks in the same batch; I'm just trimming the extra tokens. Results are impressive so far. Some numbers in case someone stumbles upon this thread in the future:

Fixed-size text batches:

Variable-size text batches:

Overall, batch text generation seems very useful (🎉), even though one has to add some overhead on top to manage some use cases.
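For reference, a rough sketch of that strategy (the function name and the per-input length bookkeeping here are illustrative, not the exact code used above): generate once per batch with the batch-wide maximum length, then trim each row back to its own budget.

```python
import torch

def generate_and_trim(model, tokenizer, prompts, max_new_tokens_per_prompt):
    """Batch-generate with the global max length, then trim each row to its own budget."""
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pad_token_id=tokenizer.eos_token_id,
            max_length=prompt_len + max(max_new_tokens_per_prompt),
        )

    # Keep each prompt plus at most its own number of newly generated tokens.
    return [
        tokenizer.decode(row[: prompt_len + budget], skip_special_tokens=True)
        for row, budget in zip(outputs, max_new_tokens_per_prompt)
    ]
```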
@cccntu Thanks for your great work! I stumbled upon this thread and would like to know:
Thanks for the code! I wonder if I could now generate sentences in a batch with other models (BertGeneration, for instance)? Looking forward to your reply!
@cccntu Thanks for your code. By using the correct position_ids in this case, we can do batch inference with the PyTorch model now. But when we export the GPT-2 model to ONNX with

```python
onnx_config = GPT2OnnxConfig(model.config)
# or using past_key_values mode
# onnx_config = GPT2OnnxConfig(model.config, use_past=True)
```

the ONNX model inputs don't contain position_ids but only input_ids and attention_mask.
Thank you for the code. I wonder if you have tested whether there is a performance drop when using batch generation, especially when the GPT-2 model is finetuned with right-padded data.


What does this PR do?
This adds the correct (absolute) positional embeddings when an attention mask is given; the position ids are computed from the attention mask.
Fixes #3021
Here is an example usage:
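A minimal sketch of such usage, assuming a stock gpt2 checkpoint and two illustrative prompts (the key point is left padding plus passing the attention mask; `sentences` is also what the benchmark above builds on):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

# GPT-2 has no pad token; reuse EOS and pad on the left so that the
# right-most position of every row is a real token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

sentences = ["Hello, my dog is a little",
             "Today, I"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)

with torch.no_grad():
    output_sequences = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        do_sample=False,                    # greedy, to compare against unbatched output
        pad_token_id=tokenizer.eos_token_id,
    )

for seq in output_sequences:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```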
outputs:
comment: maybe this can also be added to examples/text-generation/run_generation.py, but I don't know much about other models, and it (the code) would be weird if only GPT-2 supports batch inferencing.

albert, bert, GPT2, XLM: @LysandreJik
TextGeneration: @TevenLeScao
documentation: @sgugger
@patrickvonplaten